{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "UUXnh11hA75x"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2023 The IREE Authors"
      ],
      "metadata": {
        "id": "UUXnh11hA75x"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "metadata": {
        "cellView": "form",
        "id": "FqsvmKpjBJO2"
      },
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png\" height=\"20px\"> PyTorch Ahead-of-time (AOT) export workflows using <img src=\"https://raw.githubusercontent.com/iree-org/iree/main/docs/website/docs/assets/images/IREE_Logo_Icon_Color.svg\" height=\"20px\"> IREE\n",
        "\n",
        "This notebook shows how to use [iree-turbine](https://github.com/iree-org/iree-turbine) for export from a PyTorch session to [IREE](https://github.com/iree-org/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.\n",
        "\n",
        "iree-turbine contains both a \"simple\" AOT exporter and an underlying advanced\n",
        "API for complicated models and full feature availability. This notebook only\n",
        "uses the \"simple\" exporter."
      ],
      "metadata": {
        "id": "38UDc27KBPD1"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Setup"
      ],
      "metadata": {
        "id": "jbcW5jMLK8gK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "#@title Uninstall existing packages\n",
        "#   This avoids some warnings when installing specific PyTorch packages below.\n",
        "!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision"
      ],
      "metadata": {
        "id": "KsPubQSvCbXd",
        "cellView": "form"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Install Pytorch 2.3.0 (for CPU)\n",
        "!python -m pip install --index-url https://download.pytorch.org/whl/test/cpu --upgrade torch==2.3.0"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3ebWazfjJ6en",
        "outputId": "913c0ccf-9895-4c3f-cd12-20c8185c747d",
        "cellView": "form"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Looking in indexes: https://download.pytorch.org/whl/test/cpu\n",
            "Collecting torch==2.3.0\n",
            "  Downloading https://download.pytorch.org/whl/test/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl (190.4 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.4/190.4 MB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.16.1)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (4.12.2)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (1.13.1)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (3.1.4)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.3.0) (2024.10.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.3.0) (3.0.2)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.3.0) (1.3.0)\n",
            "Installing collected packages: torch\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 2.5.0+cu121\n",
            "    Uninstalling torch-2.5.0+cu121:\n",
            "      Successfully uninstalled torch-2.5.0+cu121\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "timm 1.0.11 requires torchvision, which is not installed.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed torch-2.3.0+cpu\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4iJFDHbsAzo4",
        "outputId": "5ba8936d-0cf4-40e8-d012-d987ee9ad9e3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting iree-turbine\n",
            "  Downloading iree_turbine-2.5.0-py3-none-any.whl.metadata (5.7 kB)\n",
            "Requirement already satisfied: numpy>=1.26.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (1.26.4)\n",
            "Collecting iree-compiler (from iree-turbine)\n",
            "  Downloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (615 bytes)\n",
            "Collecting iree-runtime (from iree-turbine)\n",
            "  Downloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (798 bytes)\n",
            "Requirement already satisfied: torch>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (2.3.0+cpu)\n",
            "Requirement already satisfied: Jinja2>=3.1.3 in /usr/local/lib/python3.10/dist-packages (from iree-turbine) (3.1.4)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.1.3->iree-turbine) (3.0.2)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (3.16.1)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (4.12.2)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (1.13.1)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (3.4.2)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=2.3.0->iree-turbine) (2024.10.0)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=2.3.0->iree-turbine) (1.3.0)\n",
            "Downloading iree_turbine-2.5.0-py3-none-any.whl (271 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m271.3/271.3 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading iree_compiler-20241104.1068-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (70.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m70.7/70.7 MB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading iree_runtime-20241104.1068-cp310-cp310-manylinux_2_28_x86_64.whl (8.0 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.0/8.0 MB\u001b[0m \u001b[31m52.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: iree-runtime, iree-compiler, iree-turbine\n",
            "Successfully installed iree-compiler-20241104.1068 iree-runtime-20241104.1068 iree-turbine-2.5.0\n"
          ]
        }
      ],
      "source": [
        "#@title Install iree-turbine\n",
        "\n",
        "!python -m pip install iree-turbine"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Report version information\n",
        "!echo \"Installed iree-turbine, $(python -m pip show iree_turbine | grep Version)\"\n",
        "\n",
        "!echo -e \"\\nInstalled IREE, compiler version information:\"\n",
        "!iree-compile --version\n",
        "\n",
        "import torch\n",
        "print(\"\\nInstalled PyTorch, version:\", torch.__version__)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nkVLzRpcDnVL",
        "outputId": "0481a14b-e261-4ea5-8d4e-228078b775f6"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Installed iree-turbine, Version: 2.5.0\n",
            "\n",
            "Installed IREE, compiler version information:\n",
            "IREE (https://iree.dev):\n",
            "  IREE compiler version 20241104.1068 @ 9c85e30df30d6efcf68a7a1b594e89322bd6085d\n",
            "  LLVM version 20.0.0git\n",
            "  Optimized build\n",
            "\n",
            "Installed PyTorch, version: 2.3.0+cpu\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Sample AOT workflow\n",
        "\n",
        "1. Define a program using `torch.nn.Module`\n",
        "2. Export the program using `aot.export()`\n",
        "3. Compile to a deployable artifact\n",
        "  * a: By staying within a Python session\n",
        "  * b: By outputting MLIR and continuing using native tools\n",
        "\n",
        "Useful documentation:\n",
        "\n",
        "* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation\n",
        "* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)"
      ],
      "metadata": {
        "id": "1Mi3YR75LBxl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 1. Define a program using `torch.nn.Module`\n",
        "torch.manual_seed(0)\n",
        "\n",
        "class LinearModule(torch.nn.Module):\n",
        "  def __init__(self, in_features, out_features):\n",
        "    super().__init__()\n",
        "    self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))\n",
        "    self.bias = torch.nn.Parameter(torch.randn(out_features))\n",
        "\n",
        "  def forward(self, input):\n",
        "    return (input @ self.weight) + self.bias\n",
        "\n",
        "linear_module = LinearModule(4, 3)"
      ],
      "metadata": {
        "id": "oPdjrmPZMNz6"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 2. Export the program using `aot.export()`\n",
        "import iree.turbine.aot as aot\n",
        "\n",
        "example_arg = torch.randn(4)\n",
        "export_output = aot.export(linear_module, example_arg)"
      ],
      "metadata": {
        "id": "eK2fWVfiSQ8f"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 3a. Compile fully to a deployable artifact, in our existing Python session\n",
        "\n",
        "# Staying in Python gives the API a chance to reuse memory, improving\n",
        "# performance when compiling large programs.\n",
        "\n",
        "compiled_binary = export_output.compile(save_to=None)\n",
        "\n",
        "# Use the IREE runtime API to test the compiled program.\n",
        "import numpy as np\n",
        "import iree.runtime as ireert\n",
        "\n",
        "config = ireert.Config(\"local-task\")\n",
        "vm_module = ireert.load_vm_module(\n",
        "    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),\n",
        "    config,\n",
        ")\n",
        "\n",
        "input = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)\n",
        "result = vm_module.main(input)\n",
        "print(result.to_host())"
      ],
      "metadata": {
        "id": "eMRNdFdos900",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "073486bf-1590-4838-ebc8-b76b80e87d32"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[ 1.4178505 -1.2343317 -7.4767942]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 3b. Output MLIR then continue from Python or native tools later\n",
        "\n",
        "# Leaving Python allows for file system checkpointing and grants access to\n",
        "# native development workflows.\n",
        "\n",
        "mlir_file_path = \"/tmp/linear_module_pytorch.mlirbc\"\n",
        "vmfb_file_path = \"/tmp/linear_module_pytorch_llvmcpu.vmfb\"\n",
        "\n",
        "print(\"Exported .mlir:\")\n",
        "export_output.print_readable()\n",
        "export_output.save_mlir(mlir_file_path)\n",
        "\n",
        "print(\"Compiling and running...\")\n",
        "!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=host {mlir_file_path} -o {vmfb_file_path}\n",
        "!iree-run-module --module={vmfb_file_path} --device=local-task --input=\"4xf32=[1.0, 2.0, 3.0, 4.0]\""
      ],
      "metadata": {
        "id": "0AdkXY8VNL2-",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3381872c-d048-4f99-b35d-a0ebfef00dfc"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Exported .mlir:\n",
            "module @module {\n",
            "  func.func @main(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> attributes {torch.assume_strict_symbolic_shapes} {\n",
            "    %int0 = torch.constant.int 0\n",
            "    %0 = torch.aten.unsqueeze %arg0, %int0 : !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>\n",
            "    %1 = torch.vtensor.literal(dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>) : !torch.vtensor<[4,3],f32>\n",
            "    %2 = torch.aten.mm %0, %1 : !torch.vtensor<[1,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[1,3],f32>\n",
            "    %int0_0 = torch.constant.int 0\n",
            "    %3 = torch.aten.squeeze.dim %2, %int0_0 : !torch.vtensor<[1,3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
            "    %4 = torch.vtensor.literal(dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32>) : !torch.vtensor<[3],f32>\n",
            "    %int1 = torch.constant.int 1\n",
            "    %5 = torch.aten.add.Tensor %3, %4, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>\n",
            "    return %5 : !torch.vtensor<[3],f32>\n",
            "  }\n",
            "}\n",
            "\n",
            "{-#\n",
            "  dialect_resources: {\n",
            "    builtin: {\n",
            "      torch_tensor_4_3_torch.float32: \"0x040000005C3FC53F503C96BE49710BC0B684113FA1D18ABF2D05B3BF7A83CE3EE588563F442138BF0B83CEBE18BD18BFC6673A3E\",\n",
            "      torch_tensor_3_torch.float32: \"0x04000000074F5BBF99E08C3FAB1C89BF\"\n",
            "    }\n",
            "  }\n",
            "#-}\n",
            "Compiling and running...\n",
            "EXEC @main\n",
            "result[0]: hal.buffer_view\n",
            "3xf32=1.41785 -1.23433 -7.47679\n"
          ]
        }
      ]
    }
  ]
}